import torch
from torch import nn
from torch.utils.data import TensorDataset, DataLoader
from sklearn.metrics import mean_squared_error, mean_absolute_error
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt


def data_prep(ds='chen_lsoda002', n_in=8, output_window=1, batch_size=16, trn_val=(0.7, 0.2)):
    """Load a multivariate time series from ``{ds}.csv`` and build windowed,
    min-max-scaled train/val/test DataLoaders for supervised forecasting.

    Parameters
    ----------
    ds : str
        CSV file stem; the file read is ``f'{ds}.csv'`` with the first column
        used as the index.
    n_in : int
        Input window length (number of past timesteps per sample).
    output_window : int
        Prediction horizon (number of future timesteps per target).
    batch_size : int
        Mini-batch size of the batched loaders.
    trn_val : tuple[float, float]
        Fractions of the raw sequence used for train and validation; the
        remainder (after windowing) is the test set. The series must be long
        enough that something is left over for testing.

    Returns
    -------
    (y_train_min, y_train_max) :
        Target min/max over the training split, for inverting the scaling.
    (dl_trn, dl_val, dl_test) :
        Batched DataLoaders (train is shuffled).
    (dl_val_rec, dl_test_rec) :
        Batch-size-1 loaders for recursive multi-step prediction.
    """
    ### read data
    data = pd.read_csv(f'{ds}.csv', index_col=0) # e.g. lorenz_lsoda002

    ### hyper-parameters for data prep
    trn_r, val_r = trn_val
    timesteps = n_in
    prediction_horizon = output_window
    n_timeseries = data.shape[1] # number of variables / series
    bs = batch_size

    ##### data prep for multi-vars #####
    ### Windowing: time series data -> TS dataset (X, y) for supervised learning
    seq_len = len(data)
    X = np.zeros((seq_len, timesteps, n_timeseries))
    y = np.zeros((seq_len, prediction_horizon, n_timeseries))

    for i, name in enumerate(list(data.columns)):
        for j in range(timesteps):
            # X[t, j] = value at time t-(timesteps-1-j): columns run oldest -> newest.
            # .bfill() replaces the deprecated fillna(method="bfill").
            X[:, j, i] = data[name].shift(timesteps - j - 1).bfill()
        for k in range(prediction_horizon):
            # BUG FIX: horizon step k is the value k+1 steps ahead. The original
            # used shift(-prediction_horizon) for every k, so all horizon steps
            # held the same (furthest) future value. Identical for horizon 1.
            y[:, k, i] = data[name].shift(-(k + 1)).ffill().values

    # Drop the first `timesteps` rows, whose windows were padded by back-fill.
    X = X[timesteps:]
    y = y[timesteps:]

    ### train, val, test split of X, y
    train_length = int(seq_len*trn_r)
    val_length = int(seq_len*val_r)

    X_train = X[:train_length]
    y_train = y[:train_length]
    X_val = X[train_length:train_length+val_length]
    y_val = y[train_length:train_length+val_length]
    # BUG FIX: the original test slice X[-test_length:] used a length computed
    # from the untrimmed sequence, overlapping the tail of the validation split
    # (data leakage). Take everything after the validation split instead.
    X_test = X[train_length+val_length:]
    y_test = y[train_length+val_length:]

    ### min-max scaling, fitted on the training split only
    X_train_max = X_train.max(axis=0)
    X_train_min = X_train.min(axis=0)
    y_train_max = y_train.max(axis=0)
    y_train_min = y_train.min(axis=0)

    X_train = (X_train - X_train_min) / (X_train_max - X_train_min)
    X_val = (X_val - X_train_min) / (X_train_max - X_train_min)
    X_test = (X_test - X_train_min) / (X_train_max - X_train_min)

    y_train = (y_train - y_train_min) / (y_train_max - y_train_min)
    y_val = (y_val - y_train_min) / (y_train_max - y_train_min)
    y_test = (y_test - y_train_min) / (y_train_max - y_train_min)

    ### to torch tensors
    X_train_t = torch.Tensor(X_train)
    X_val_t = torch.Tensor(X_val)
    X_test_t = torch.Tensor(X_test)
    y_train_t = torch.Tensor(y_train)
    y_val_t = torch.Tensor(y_val)
    y_test_t = torch.Tensor(y_test)

    ### Batching: dataset -> dataloader (dl)
    dl_trn = DataLoader(TensorDataset(X_train_t, y_train_t), shuffle=True, batch_size=bs)
    dl_val = DataLoader(TensorDataset(X_val_t, y_val_t), shuffle=False, batch_size=bs)
    dl_test = DataLoader(TensorDataset(X_test_t, y_test_t), shuffle=False, batch_size=bs)

    # batch-size-1 loaders for the recursive (closed-loop) prediction tests
    dl_val_rec = DataLoader(TensorDataset(X_val_t, y_val_t), shuffle=False, batch_size=1)
    dl_test_rec = DataLoader(TensorDataset(X_test_t, y_test_t), shuffle=False, batch_size=1)

    return (y_train_min, y_train_max), (dl_trn, dl_val, dl_test), (dl_val_rec, dl_test_rec)

def run_experiment(
    trn_minmax, dls, model,
    model_name,
    model_file,
    loss=nn.MSELoss(),
    lr=0.0015, # also tried 0.001
    epochs=300, # also tried 150, 250, 500
    lr_renew_epochs=100,
    lr_schedule_steps=10,
    patience=25, # also tried 10, 30, 50
    must_train=False,
    val_again=True,
    verbose=0,
):
    """Train `model` (unless a usable checkpoint loads), then test it.

    Parameters
    ----------
    trn_minmax : (y_train_min, y_train_max) arrays from `data_prep`, used to
        invert the min-max scaling before computing RMSE/MAE.
    dls : (dl_trn, dl_val, dl_test) DataLoaders from `data_prep`.
    model : torch module; batches are moved with `.cuda()`, so a CUDA device
        is required and the model is presumably already on the GPU — confirm.
    model_name : label used in printouts and plot legends.
    model_file : checkpoint path for the state dict (torch.save/torch.load).
    loss : training criterion (default MSE).
    lr, epochs, lr_renew_epochs, lr_schedule_steps, patience : optimization
        hyper-parameters; after `patience` epochs without a saved improvement
        the loop either renews the LR schedule or stops (see below).
    must_train : force training even when the checkpoint loads successfully.
    val_again : re-run validation after testing and return avg(val, test) RMSE.
    verbose : 0 silent, 1 adds a test plot, >1 adds per-epoch logs and plots.

    Returns
    -------
    float : np.average([rmse1, rmse2]) over val+test when `val_again`,
        otherwise np.average(rmse2) over the test series.

    NOTE(review): with the default `val_again=True` the function returns inside
    that branch, so the `verbose > 0` test-plot block below it never executes.
    """
    y_train_min, y_train_max = trn_minmax
    n_timeseries = y_train_max.shape[1]
    dl_trn, dl_val, dl_test = dls
    
    try:
        model.load_state_dict(torch.load(model_file)) # load an existing checkpoint
        train = False
    except IOError: # FileNotFoundError, PermissionError
        train = True
    except RuntimeError: # "Error(s) in loading state_dict": architecture changed
        train = True # so retrain from scratch

    def validate(data_loader):
        # Run the model over `data_loader` without gradients; returns
        # (predictions, ground truth, loss summed over samples).
        with torch.no_grad(): # validating or testing
            mse_val = 0
            preds = []
            true = []
            for batch_x, batch_y in data_loader:
                batch_y = batch_y.cuda()
                output = model(batch_x.cuda())
                output = output.squeeze(1)
                preds.append(output.detach().cpu().numpy())
                true.append(batch_y.detach().cpu().numpy())
                # Weight the batch-mean loss by batch size: mse_val is a SUM over
                # samples, never divided by the dataset size.
                mse_val += loss(output, batch_y.squeeze()).item()*batch_x.shape[0]
        return np.concatenate(preds).squeeze(), np.squeeze(np.concatenate(true)), mse_val

    if train or must_train:
        opt = torch.optim.Adam(model.parameters(), lr=lr)

        ### model training, validating and saving the best weights
        min_val_loss = 9999 # sentinel: first epoch will (almost) always improve
        counter = 0 # epochs since the last checkpoint save (early-stop counter)
        lr_renew_counter = 0 # epochs elapsed; jumped forward on early stop below
        for j in range(epochs):
            if (lr_renew_counter % lr_renew_epochs == 0):
                # (Re)create the scheduler every `lr_renew_epochs` epochs so the
                # step decay restarts from the optimizer's current learning rate.
                epoch_scheduler = torch.optim.lr_scheduler.StepLR(optimizer=opt, step_size=lr_schedule_steps, gamma=0.9)
            mse_train = 0
            for batch_x, batch_y in dl_trn: # training pass
                batch_x = batch_x.cuda()
                batch_y = batch_y.cuda()
                opt.zero_grad()
                y_pred = model(batch_x)
                y_pred = y_pred.squeeze(1)
                l = loss(y_pred, batch_y.squeeze())
                l.backward()
                mse_train += l.item()*batch_x.shape[0] # sum over samples, like mse_val
                opt.step()
            epoch_scheduler.step()
            lr_renew_counter += 1
            
            # validating
            preds, true, mse_val = validate(dl_val)
            
            # NOTE(review): mse_val is a sum, so rmse_val is only a relative
            # score for checkpoint selection, not a true RMSE.
            rmse_val = mse_val**0.5
            val_loss_diff = rmse_val - min_val_loss
            if np.abs(val_loss_diff) > 1e-4*rmse_val and val_loss_diff < 0: # saving
                # skip saving when the improvement is below a 1e-4 relative margin
                min_val_loss = rmse_val
                print(f"Saving... with min_val_loss = {min_val_loss:.4f}, at epoch {j}") # ,end=''
                torch.save(model.state_dict(), model_file) # save only the weights
                counter = 0
            else: 
                counter += 1

            if counter == patience:
                print(f'Early stop at epoch {j}, with train MSE {(mse_train):.4f} and val MSE: {(mse_val):.4f}')
                if lr_renew_counter < epochs:
                    # Instead of stopping outright: jump the counter to the next
                    # multiple of lr_renew_epochs (forcing a fresh scheduler at
                    # the top of the loop), reset the best-loss tracker, and keep
                    # training. Only break once the jumps exceed `epochs`.
                    lr_renew_counter = ((lr_renew_counter // lr_renew_epochs) + 1) * lr_renew_epochs
                    min_val_loss = 99 # reset 
                    continue
                else:
                    break
            elif j==epochs-1:
                print(f'{j+1} epochs finished with train MSE {(mse_train):.4f} and val MSE: {(mse_val):.4f}')
            if verbose > 1:
                print(f'Epoch:{j:3d}, train MSE: {(mse_train):.4f}, val MSE: {(mse_val):.4f}')
                if (j % 10 == 0):
                    # Every 10 epochs: de-scale and plot the validation predictions.
                    preds = preds*(y_train_max[:] - y_train_min[:]) + y_train_min[:]
                    true = true*(y_train_max[:] - y_train_min[:]) + y_train_min[:]
                    fig, axs = plt.subplots(n_timeseries, 1, sharex=True, figsize=(12,2*n_timeseries), dpi=85)
                    fig.subplots_adjust(hspace=0) # Remove horizontal space between axes
                    rmse = list(range(n_timeseries))
                    mae = list(range(n_timeseries))
                    for i in range(n_timeseries): # per-series metrics on the validation set
                        rmse[i] = np.sqrt(mean_squared_error(true[:,i], preds[:,i]))
                        mae[i] = mean_absolute_error(true[:,i], preds[:,i])
                        axs[i].plot(true[:,i], label='Ground Truth')
                        axs[i].plot(preds[:,i], label=model_name) # Predicted
                        axs[i].legend()
                    rmse_ = [f'{x:.3f}' for x in rmse]
                    mae_ = [f'{x:.3f}' for x in mae]
                    print(f'valid RMSE: {rmse_}, MAE: {mae_}')
                    plt.show()

    print(f'==> Load the saved {model_name} model ...')
    model.load_state_dict(torch.load(model_file)) # reload the best checkpoint
    
    ### model testing
    preds2, true2, mse_val2 = validate(dl_test)

    ### inverse min-max scale transform
    preds2 = preds2*(y_train_max[:] - y_train_min[:]) + y_train_min[:]
    true2 = true2*(y_train_max[:] - y_train_min[:]) + y_train_min[:]
    
    ### per-series test metrics
    rmse2 = list(range(n_timeseries))
    mae2 = list(range(n_timeseries))
    for i in range(n_timeseries):
        rmse2[i] = np.sqrt(mean_squared_error(true2[:,i], preds2[:,i]))
        mae2[i] = mean_absolute_error(true2[:,i], preds2[:,i])
    rmse2_ = [f'{x:.3f}' for x in rmse2]
    mae2_ = [f'{x:.3f}' for x in mae2]
    # NOTE(review): "tset" is a typo for "test" in the printed label.
    print(f"{'x'*56}\n tset RMSE: {rmse2_}, MAE: {mae2_}")
    
    if val_again: # validate once more and report valid/test ratios
        preds1, true1, mse_val1 = validate(dl_val)
        ### inverse min-max scale transform
        preds1 = preds1*(y_train_max[:] - y_train_min[:]) + y_train_min[:]
        true1 = true1*(y_train_max[:] - y_train_min[:]) + y_train_min[:]

        ### per-series validation metrics
        rmse1 = list(range(n_timeseries))
        mae1 = list(range(n_timeseries))
        for i in range(n_timeseries):
            rmse1[i] = np.sqrt(mean_squared_error(true1[:,i], preds1[:,i]))
            mae1[i] = mean_absolute_error(true1[:,i], preds1[:,i])
        rmse1_ = [f'{x:.3f}' for x in rmse1]
        mae1_ = [f'{x:.3f}' for x in mae1]
        print(f'valid RMSE: {rmse1_}, MAE: {mae1_}')
        rmse12 = [f'{x/y:.3f}' for x,y in zip(rmse1,rmse2)]
        mae12 = [f'{x/y:.3f}' for x,y in zip(mae1,mae2)]
        print(f'valid/test RMSE: {rmse12}, MAE: {mae12}')
        
        # NOTE(review): this return makes the `verbose > 0` plot below dead
        # code whenever val_again is True (the default).
        return np.average([rmse1, rmse2])

    if verbose > 0: # verbose == 1 or 2
        fig2, axs2 = plt.subplots(n_timeseries, 1, sharex=True, figsize=(12,2*n_timeseries), dpi=85)
        fig2.subplots_adjust(hspace=0)
        for i in range(n_timeseries):
            axs2[i].plot(true2[:,i], label='Ground Truth')
            axs2[i].plot(preds2[:,i], label=model_name) # Predicted
            axs2[i].grid(ls='--')
            axs2[i].legend()
        plt.show()
    
    return np.average(rmse2)


# model = RNN(name='LSTM', input_length=timesteps, target_length=prediction_horizon, hidden_layer_size=100, num_layers=1, bidirectional=False, input_size=n_timeseries, target_size=n_timeseries).cuda()
# run_experiment(model, model_name='LSTM', model_file='lstm_m_lorenz0225.pt', epochs=250, patience=35) # , must_train=True, patience=20



##### Constants for the recursive (REC) multi-step prediction tests #####
dt = 0.02  # x-axis time scale per step in the plots (presumably the data's sampling step — confirm)
fontsize = 16
font = {'family': 'serif', 'size': fontsize}
error_bound = 1.5  # error threshold used to measure the effective prediction length
linecolors1 = ['#457b9d','#fbc53e','#a71e34'] # blue, orange/yellow, red (validation plots)
linecolors2 = ['#25abe2','#f99f1b','#ee446f'] # lighter variants (test plots)


def run_experiment_rec_n(
    trn_minmax, dls_rec, model,
    model_name,
    model_file,
    N_plot=201,
    N_rec=100, # number of recursive multi-step prediction trials
    N_slide=3, # stride (in samples) between the start points of consecutive trials
    val_again=True,
):
    """Closed-loop (recursive) multi-step prediction test, repeated N_rec times.

    Starting from `N_rec` positions spaced `N_slide` samples apart, the model
    is fed its own predictions for `N_plot` steps. Absolute errors are averaged
    over the trials, and the "effective prediction length" per series is the
    number of steps whose cumulative absolute error stays below the
    module-level `error_bound`.

    Parameters
    ----------
    trn_minmax : (y_train_min, y_train_max) from `data_prep`, to invert scaling.
    dls_rec : (dl_val_rec, dl_test_rec) batch-size-1 loaders from `data_prep`.
    model : torch module; inputs are moved with `.cuda()`, so a CUDA device is
        required and the model is presumably already on the GPU — confirm.
    model_name : label for printouts and plot legends.
    model_file : checkpoint path; this function only loads, never trains.
    N_plot : number of recursive steps per trial.
    N_rec : number of trials per split.
    N_slide : offset between trial start positions.
    val_again : also run the trials on the validation loader.

    Returns
    -------
    (epl_val, epl_test) : per-series effective prediction lengths (averaged
        over trials); `epl_val` is None when `val_again` is False. Returns
        None early if the checkpoint cannot be loaded.
    """
    y_train_min, y_train_max = trn_minmax
    dl_val_rec, dl_test_rec = dls_rec
    # Peek at one batch to recover the target window shape.
    for batch_x, batch_y in dl_test_rec:
        break
    prediction_horizon, n_timeseries = batch_y.size(1), batch_y.size(2)

    try:
        print(f'==> Load the saved {model_name} model ...')
        model.load_state_dict(torch.load(model_file)) # load the trained checkpoint
    except IOError: # FileNotFoundError, PermissionError
        print("model file not found, \n please use `run_experiment` to train first. ")
        return

    def validate_rec_n(data_loader_rec, N_start):
        # One closed-loop trial: seed the input window with the true sample at
        # position N_start, then repeatedly predict and slide the model's own
        # output into the window for N_plot steps.
        with torch.no_grad():
            preds = []
            true = []
            N_end = N_start + N_plot
            for k, (batch_x, batch_y) in enumerate(data_loader_rec):
                if k < N_start:
                    continue # skip samples before the seed position
                elif k == N_start:
                    batch_x_rec = batch_x[:] # seed window: the true input at N_start
                elif k >= N_end:
                    break # collected N_plot steps
                output = model(batch_x_rec.cuda())
                # Slide the window: drop the oldest `prediction_horizon` steps
                # and append the model's own prediction (recursive forecasting).
                a = batch_x_rec[:,prediction_horizon:,:].clone()
                batch_x_rec[:,:-prediction_horizon,:] = a
                batch_x_rec[:,-prediction_horizon:,:] = output
                preds.append(output.squeeze(1).detach().cpu().numpy())
                true.append(batch_y.numpy())
        return np.concatenate(preds), np.squeeze(np.concatenate(true))

    def evaluate_rec_n(true_, preds_, linecolors, s='test'):
        # De-scale all trials, average absolute errors over trials, plot the
        # last trial as an example, and compute effective prediction lengths.
        ### inverse min-max scale transform
        preds_ = [preds0*(y_train_max[:] - y_train_min[:]) + y_train_min[:] for preds0 in preds_]
        true_ = [true0*(y_train_max[:] - y_train_min[:]) + y_train_min[:] for true0 in true_]

        ae_ = [np.abs(preds0[:]-true0[:]) for (preds0,true0) in zip(preds_,true_)]
        ae_n = np.average(ae_, axis=0) # absolute error averaged over trials
        cae_n = np.cumsum(ae_n, axis=0) # cumulative averaged absolute error
        eff_pred_len_n = list(range(n_timeseries))

        preds0, true0 = preds_[-1], true_[-1] # last trial serves as the plotted example
        N_plot0 = len(preds0)
        ae = np.abs(preds0[:]-true0[:]) # absolute error of the example trial
        cae = np.cumsum(ae, axis=0) # its cumulative absolute error
        eff_pred_len = list(range(n_timeseries))

        fig, axs = plt.subplots(2*n_timeseries, 1, sharex=True, figsize=(10,2.5*n_timeseries), dpi=85)
        fig.subplots_adjust(hspace=0) # Remove horizontal space between axes
        for i in range(n_timeseries):
            cae_ = cae[:,i]
            cae_i = cae_n[:,i]
            # Effective prediction length: steps before the cumulative error
            # first exceeds error_bound.
            eff_pred_len[i] = len(cae_[np.where(cae_ < error_bound)])
            eff_pred_len_n[i] = len(cae_i[np.where(cae_i < error_bound)])
            axs[2*i].plot(np.array(range(N_plot0))*dt, true0[:,i], color=linecolors[0], label='Ground Truth')
            axs[2*i].plot(np.array(range(N_plot0))*dt, preds0[:,i], color=linecolors[1], label=model_name)
            axs[2*i].axvline(x=eff_pred_len[i]*dt, color="black", linestyle="--")
            axs[2*i].grid(ls='-')
            axs[2*i].legend(loc="upper right")
            axs[2*i+1].plot(np.array(range(N_plot0))*dt, ae_n[:,i], color=linecolors[2], label='Averaged L1 Error')
            axs[2*i+1].axhline(y=error_bound, color="black", linestyle="--")
            axs[2*i+1].grid(ls='-')
            axs[2*i+1].legend()
        axs[2*i+1].set_xlabel(r"Time $t$", **font)
        print(f"{model_name} gives a {*eff_pred_len,}-step effective prediction in {s} set \n{*eff_pred_len_n,}-step effective prediction on average")
        plt.show()
        return eff_pred_len_n # epl

    # BUG FIX: epl_val was referenced in the final return even when
    # val_again=False, raising NameError; initialize it up front.
    epl_val = None

    if val_again: # run the trials on the validation split first
        preds1 = list(range(N_rec))
        true1 = list(range(N_rec))
        for j in range(N_rec):
            preds1[j], true1[j] = validate_rec_n(dl_val_rec, j*N_slide)

        ### evaluating
        epl_val = evaluate_rec_n(true1, preds1, linecolors1, 'valid')

    ### model testing
    preds2 = list(range(N_rec))
    true2 = list(range(N_rec))
    for j in range(N_rec):
        preds2[j], true2[j] = validate_rec_n(dl_test_rec, j*N_slide)

    ### evaluating
    epl_test = evaluate_rec_n(true2, preds2, linecolors2, 'test')

    return epl_val, epl_test

# run_experiment_rec_n(model, model_name='LSTM', model_file='lstm_m_lorenz0225.pt', N_plot=151, N_rec=100, N_slide=4)
